package com.lm.deeplearning4j.service;

import java.util.Objects;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.stereotype.Service;

@Service
public class PredictService {

    /**
     * Number of feature values fed to the network per sample (the original code hard-coded this
     * as {@code reshape(1, 4)}). NOTE(review): presumably this must equal the first layer's input
     * size — confirm against the model configuration.
     */
    private static final int FEATURE_COUNT = 4;

    // NOTE(review): this service likely holds one network shared by concurrent callers; confirm
    // that MultiLayerNetwork.output is safe to call concurrently (DL4J offers ParallelInference
    // for multi-threaded inference) — TODO verify.
    private final MultiLayerNetwork model;

    /**
     * Creates the service around an already-built network.
     *
     * @param model the network used for inference; must not be null
     * @throws NullPointerException if {@code model} is null
     */
    public PredictService(MultiLayerNetwork model) {
        this.model = Objects.requireNonNull(model, "model");
    }

    /**
     * Runs a single sample through the network and reports the predicted class.
     *
     * @param features the feature values of one sample; must contain exactly
     *     {@value #FEATURE_COUNT} entries
     * @return a human-readable string containing the index of the highest-scoring output column
     *     and the string form of the raw network output
     * @throws NullPointerException if {@code features} is null
     * @throws IllegalArgumentException if {@code features} does not have exactly
     *     {@value #FEATURE_COUNT} entries
     */
    public String predict(double[] features) {
        Objects.requireNonNull(features, "features");
        if (features.length != FEATURE_COUNT) {
            throw new IllegalArgumentException(
                    "Expected " + FEATURE_COUNT + " features but got " + features.length);
        }
        // Reshape into a single [1, FEATURE_COUNT] row so the sample is passed as a batch of one
        // (presumably the network's expected [batch, features] layout).
        INDArray input = Nd4j.create(features).reshape(1, FEATURE_COUNT);
        INDArray output = model.output(input);
        // argMax along dimension 1 yields, for row 0, the column index of the largest output.
        int classIdx = Nd4j.argMax(output, 1).getInt(0);
        return "Predicted class: " + classIdx + ", Output=" + output;
    }
}
